from peft import PeftModel
import torch
from transformers import AutoTokenizer,AutoModelForCausalLM
from transformers import BitsAndBytesConfig
import json
import pickle
import argparse

# Command-line interface. Parsing runs at module import time, so `args` is a
# module-level global that att_hook() and main() read directly.
parser = argparse.ArgumentParser()
# --size is interpolated into the model name and into input/output file names in main().
parser.add_argument('--size', type=str, help="This argument indicate size of model like 7b, 13b")
# --mode is compared against "zero", "len" and "prim" in main().
parser.add_argument('--mode', type=str, help="This argumnet indicate category of prompt like no-constraint, priming")
# --quan is compared against "4bit", "8bit" and "full" in main().
parser.add_argument('--quan', type=str, help="This argumnet indicate quantization option")
# --k is the number of indices get_topk() selects.
parser.add_argument('--k', type=int, help="This argumnet indicate how many top- or low- k, do you want.")
# --mul is the factor att_hook() multiplies the selected output units by.
parser.add_argument('--mul', type=float, help="This argumnet indicate multiplying scale")
# --top and --ft arrive as strings and are converted to bool with str2bool() further down.
parser.add_argument('--top', type=str, default="t", help="This argument inidicate top or low mode")
parser.add_argument('--ft', type=str, default="f", help="This argument indicate if you using finetuning model")
args = parser.parse_args()

def get_topk(path, k, largest:bool):
    """Return the positions of the ``k`` largest or smallest scores in a text file.

    The last six characters of every stripped line are parsed as a float.
    Lines whose tail is not a valid float are skipped and do not consume a
    position. Reading stops at the first blank line or at end of file.

    Args:
        path: Path of the text file to read.
        k: Number of positions to select.
        largest: ``True`` selects the largest scores, ``False`` the smallest.

    Returns:
        The index tensor produced by ``torch.topk`` over the parsed scores,
        i.e. positions counted over the successfully parsed lines.

    Raises:
        OSError: If the file cannot be opened.
        Any error ``torch.topk`` raises for an unusable ``k`` is propagated.
    """
    scores = []
    # ``with`` closes the handle even when parsing or the top-k call raises.
    with open(path, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # A blank line ends the score listing.
                break
            try:
                scores.append(float(line[-6:]))
            except ValueError:
                # Not a score line (e.g. a header): skip it, but let
                # KeyboardInterrupt and other real errors propagate.
                continue

    _, idx = torch.topk(torch.tensor(scores), k, largest=largest)
    return idx

def str2bool(v):
    """Interpret a command-line flag value as a boolean.

    Booleans are returned unchanged. Strings are matched case-insensitively
    against a fixed set of truthy and falsy spellings.

    Raises:
        argparse.ArgumentTypeError: If the string matches neither set.
    """
    if isinstance(v, bool):
        return v

    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    if v.lower() in truthy:
        return True
    if v.lower() in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
    
# Convert the string flags parsed above into real booleans.
args.ft = str2bool(args.ft)
args.top = str2bool(args.top)
# Indices of the output units att_hook() rescales; main() replaces this with
# the result of get_topk().
length_index = []
# Number of decoder layers; main() sets it from get_NumDecoders() and
# att_hook() uses it as the modulus for its call counter.
num_decoders = 0

# Total number of att_hook() invocations so far.
att_out_count = 0
# Incremented by att_hook() whenever its call counter modulo num_decoders is 1;
# inference() resets it to 0 after every batch.
att_decoder_len = 0

def att_hook(module, input, output):
    """Forward hook that rescales selected units of a module's output.

    main() registers this hook on ``self_attn.o_proj`` of every decoder
    layer. The hook keeps a running call counter; ``layer_idx`` is that
    counter modulo ``num_decoders``. ``att_decoder_len`` is bumped whenever
    ``layer_idx`` is 1, which presumably marks the start of a new forward
    pass (assumes the hooks fire once per layer, in layer order - TODO
    confirm).

    The output is modified only when all of the following hold:
    ``att_decoder_len > 1``, ``layer_idx == 2`` and the output's second
    dimension has size 1. In that case the entries
    ``output[:, 0, length_index]`` are multiplied in place by ``args.mul``.

    Args:
        module: The hooked module (unused).
        input: The module's inputs (unused).
        output: The module's output tensor, indexed as ``[b, 0, unit]``.

    Returns:
        ``output``, modified in place when the conditions above are met.
    """
    global att_out_count
    global att_decoder_len

    att_out_count += 1
    layer_idx = att_out_count % num_decoders
    if layer_idx == 1:
        att_decoder_len += 1

    if att_decoder_len > 1 and layer_idx == 2 and output.shape[1] == 1:
        # One indexed update instead of a Python loop over every
        # (unit, batch row) pair. The indices produced by torch.topk are
        # distinct, so each selected unit is scaled exactly once, as before.
        # Moving the indices to the output's device keeps the indexing valid
        # wherever the layer lives.
        unit_idx = torch.as_tensor(length_index, dtype=torch.long, device=output.device)
        data = output.data
        data[:, 0, unit_idx] = data[:, 0, unit_idx] * args.mul

    return output

def load_prompts_from_json(file_path: str):
    """Load a JSON Lines file into a list of decoded objects.

    Every non-blank line is parsed with ``json.loads``. Blank or
    whitespace-only lines are skipped so that stray empty lines in the file
    do not abort loading. The file is read as UTF-8 regardless of the
    platform's default encoding.

    Args:
        file_path: Path of the ``.jsonl`` file.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]

def inference(
        model:AutoModelForCausalLM,
        tokenizer:AutoTokenizer,
        prompts:list[str],
        batch_size:int = 16,
        **kwargs,
) -> list[str]:
    """Generate text for every prompt, batch by batch.

    The model is put in eval mode and run under ``torch.no_grad()``. Prompts
    are sliced into consecutive chunks of ``batch_size`` and handed to
    ``process_batch``; after each chunk the module-level ``att_decoder_len``
    counter is reset to 0. Finally every occurrence of the prompt text and of
    the literal ``"$}}%"`` is removed from the corresponding generated text.

    Args:
        model: Model passed through to ``process_batch``.
        tokenizer: Tokenizer passed through to ``process_batch``.
        prompts: Prompt strings, in the order results should be returned.
        batch_size: Number of prompts per ``process_batch`` call.
        **kwargs: Forwarded unchanged to ``process_batch``.

    Returns:
        One cleaned-up generated string per prompt.
    """
    global att_decoder_len

    model.eval()
    from tqdm import tqdm

    raw_outputs = []
    with torch.no_grad():
        for start in tqdm(range(0, len(prompts), batch_size)):
            chunk = prompts[start:start + batch_size]
            raw_outputs.extend(process_batch(model, tokenizer, chunk, **kwargs))
            # Restart the forward-pass counter used by att_hook for the next chunk.
            att_decoder_len = 0

    cleaned = []
    for text, prompt in zip(raw_outputs, prompts):
        without_prompt = text.replace(prompt, "")
        cleaned.append(without_prompt.replace("$}}%", ""))

    return cleaned

def process_batch(
        model:"AutoModelForCausalLM",
        tokenizer:"AutoTokenizer",
        batch:list[str],
        **kwargs,
) -> list[str]:
    """Generate text for one batch, halving the batch on CUDA errors.

    The batch is tokenized with padding, moved to ``model.device`` and passed
    to ``model.generate``; the outputs are decoded with special tokens
    skipped. If ``generate`` raises a ``RuntimeError`` whose message contains
    ``"CUDA"``, the batch is split in two halves that are processed
    recursively and concatenated in the original order.

    Args:
        model: Object providing ``device`` and ``generate(**inputs, **kwargs)``.
        tokenizer: Callable tokenizer that also provides ``batch_decode``.
        batch: Prompt strings to generate from.
        **kwargs: Forwarded unchanged to ``model.generate``.

    Returns:
        One decoded string per prompt in ``batch``.

    Raises:
        RuntimeError: Re-raised unchanged when it is not a CUDA error; raised
            anew when a CUDA error occurs for a batch that can no longer be
            split (a single prompt).
        SystemExit: On ``KeyboardInterrupt``, after printing it.
    """
    global att_out_count
    global att_decoder_len

    model_inputs = tokenizer(batch, return_tensors="pt", padding=True).to(model.device)

    # Start every generate() call from a clean hook state. A call that fails
    # part-way through a forward pass would otherwise leave att_hook's
    # counters misaligned, so the retries below would be counted from the
    # wrong position.
    att_out_count = 0
    att_decoder_len = 0

    cuda_error_message = None
    try:
        model_outputs = model.generate(**model_inputs,**kwargs)
        return tokenizer.batch_decode(model_outputs, skip_special_tokens=True)
    except KeyboardInterrupt as ke:
        print(ke)
        raise SystemExit from ke
    except RuntimeError as err:
        print(err)
        if "CUDA" not in str(err):
            # Not a memory problem: surface the real error instead of
            # failing later on an unbound result variable.
            raise
        # Keep only the text. Holding the exception object would keep its
        # traceback, and with it the tensors of the failed call, alive.
        cuda_error_message = str(err)

    # Recovery happens outside the ``except`` block so that the failed call's
    # frames are already released when memory is reclaimed.
    if len(batch) <= 1:
        # Nothing left to split; retrying would recurse forever.
        raise RuntimeError(
            f"CUDA error for a batch that cannot be split further: {cuda_error_message}"
        )

    import gc
    del model_inputs
    gc.collect()
    torch.cuda.empty_cache()

    temp_batch_size = len(batch)//2
    print("temp_batch_size:",temp_batch_size)

    generated_text_1 = process_batch(model,tokenizer,batch[:temp_batch_size],**kwargs)
    generated_text_2 = process_batch(model,tokenizer,batch[temp_batch_size:],**kwargs)

    return generated_text_1 + generated_text_2

def get_NumDecoders(model):
    """Derive a layer count from the model's parameter names.

    Takes the third-from-last name yielded by ``model.named_parameters()``,
    reads the third dot-separated component as an integer index and returns
    that index plus one. Presumably that name belongs to the last decoder
    layer (e.g. ``model.layers.<i>. ...``) - verify for other architectures.
    """
    param_names = []
    for name, _param in model.named_parameters():
        param_names.append(name)

    third_from_last = param_names[-3]
    layer_number = third_from_last.split(".")[2]
    return int(layer_number) + 1

def main(
        batch_size:int = 50,
):
    """Run the unit-scaling generation experiment end to end.

    Steps, in order:
      1. Validate ``--mode`` and ``--quan`` before any expensive work.
      2. Read the unit indices to rescale with ``get_topk``.
      3. Load the Llama-2 chat model (optionally quantized) and, with
         ``--ft``, wrap it with the LoRA adapter for the chosen mode.
      4. Register ``att_hook`` on ``self_attn.o_proj`` of every decoder layer.
      5. Build one prompt per record of ``google_test.jsonl``.
      6. Generate greedily (``do_sample=False``, ``max_new_tokens=150``) and
         pickle the list of generated strings.

    Args:
        batch_size: Number of prompts handed to ``inference`` per batch.

    Raises:
        ValueError: If ``--mode`` is not one of ``zero``/``len``/``prim`` or
            ``--quan`` is not one of ``4bit``/``8bit``/``full``.
    """
    import os  # only needed here, to create the output directory

    global length_index
    global num_decoders

    # Reject unusable options up front. Previously a bad --mode was reported
    # by an ``assert`` (stripped under ``python -O``) only after the model had
    # been loaded, and a bad --quan surfaced as an unbound-variable error.
    if args.mode not in ("zero", "len", "prim"):
        raise ValueError("Wrong args \'--mode\'")
    if args.quan not in ("4bit", "8bit", "full"):
        raise ValueError("Wrong args \'--quan\'")

    # get indices of top or smallest-k unit
    length_index = get_topk(f'./tracking_result/4bit_{args.mode}_att_llama2_{args.size}_second.txt', args.k, args.top)

    # Decide where results go and make sure the directory exists *before*
    # the long generation run, so the results cannot be lost at write time.
    top = "top" if args.top else "low"
    text_name = f"./generated_text/edit100/4bit_13b_ft/{args.mode}_{args.size}_{top}{args.k}_{args.mul}.pkl"
    os.makedirs(os.path.dirname(text_name), exist_ok=True)

    model_name = f"meta-llama/Llama-2-{args.size}-chat-hf"

    # set quantization option
    load_kwargs = {
        "device_map": "auto",
        "cache_dir": "/data/huggingface_models/",
    }
    if args.quan=="4bit":
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            # The bnb_4bit_* options only configure 4-bit loading;
            # load_in_4bit is the switch that requests it.
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif args.quan=="8bit":
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=True,
        )
    # args.quan == "full": no quantization_config is passed at all.

    # set model
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    model.config.pad_token_id = model.config.eos_token_id

    # get number of decoder layers; att_hook uses it to tell layers apart
    num_decoders = get_NumDecoders(model)

    # hook attention outputs
    if args.ft:
        # if finetuning model, load peftmodel
        print("fine tuning mode")
        lora_paths = {
            "zero": "./pretrains/lora_llama/models/zero-shot-10000/",
            "len": "./pretrains/lora_llama/models/len0-10000/",
            "prim": "./pretrains/lora_llama/models/len2-10000/",
        }
        model = PeftModel.from_pretrained(model, lora_paths[args.mode])
        # The PEFT wrapper adds one more ``.model`` level above the layers.
        decoder_layers = model.model.model.layers
    else:
        decoder_layers = model.model.layers

    for layer in decoder_layers:
        layer.self_attn.o_proj.register_forward_hook(att_hook)

    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        use_fast=True,
        padding_side="left",
        cache_dir="/data/huggingface_models/"
    )
    tokenizer.pad_token = tokenizer.eos_token

    # select prompt mode for loading data
    # NOTE: the prompt wording (including its spelling) is model input and is
    # intentionally left exactly as it was.
    json_data = load_prompts_from_json("./train-test-valid/google_test.jsonl")
    if args.mode=="zero":
        print("="*50+"\nload zeroshot data")
        train_prompts = [f"Sentece:\n{i['text']}\nThe sentence without the less important words would be:\n" for i in json_data]
    elif args.mode=="len":
        print("="*50+"\nload length constraint data")
        train_prompts = [f"Sentece:\n{i['text']}\nThe sentence without the less important {len(i['text'].split())-len(i['summaries'][0].split())} words would be:\n" for i in json_data]
    else:
        # args.mode == "prim" (validated at the top of the function)
        print("="*50+"\nload priming data")
        train_prompts = [f"Sentence that consists of {len(i['text'].split())} words:\n{i['text']}\nThe sentence that consists of {len(i['summaries'][0].split())} words without the less important {len(i['text'].split())-len(i['summaries'][0].split())} words would be:\n" for i in json_data]
    del json_data

    # inference data
    generated_text = inference(
        model=model,
        tokenizer=tokenizer,
        prompts=train_prompts,
        batch_size=batch_size,
        max_new_tokens=150,
        do_sample=False,
    )

    # Results are a plain list of strings; only load this pickle from a
    # trusted location, as with any pickle file.
    with open(text_name,'wb') as f:
        pickle.dump(generated_text,f)
    
# Script entry point: run the experiment only when this file is executed
# directly (argument parsing above still happens on import).
if __name__ == "__main__":
    main()